suppressMessages(library(readr))
suppressMessages(library(dplyr))
suppressMessages(library(ggplot2))
suppressMessages(library(vtreat))
suppressMessages(library(caret))
suppressMessages(library(FSelector))
suppressMessages(library(rpart))
suppressMessages(library(rpart.plot))
# Enable knitr chunk caching for this document.
knitr::opts_chunk$set(cache=TRUE)
# Raw training data; the readr progress bar is switched off.
train.data <- read_csv("train.csv", progress=FALSE)
#test.data <- read_csv("test.csv")
print("Using vtreat package to prepare data")
## [1] "Using vtreat package to prepare data"
# Load a treatment plan saved earlier and apply it to the raw data.
# NOTE(review): pruneSig=1 presumably disables significance-based variable
# pruning -- confirm against the vtreat documentation.
treatments <- readRDS('treatments.RDS')
train <- prepare(treatments, train.data, pruneSig=1)
#test  <- prepare(treatments, test.data, pruneSig=1)
# Test whether two columns hold the same values position by position.
#
# Args:
#   col1, col2: vectors (or single-column data frames) of equal length.
#
# Returns:
#   A single TRUE when every position matches, otherwise FALSE. Two missing
#   values in the same position count as a match; a missing value opposite a
#   present value counts as a mismatch.
is.duplicated <- function(col1, col2){
    equal <- col1 == col2
    both.missing <- is.na(col1) & is.na(col2)
    # `==` yields NA wherever either side is missing, so a plain comparison
    # would silently ignore those positions. Treat NA results as mismatches
    # unless both sides are missing.
    return(all(both.missing | (!is.na(equal) & equal)))
}

print("Checking for duplicated columns")
## [1] "Checking for duplicated columns"
# Find columns that repeat an earlier column of `dframe`.
#
# Every pair (i, j) with i < j is compared with is.duplicated(); when the pair
# matches, the later column j is flagged.
#
# Args:
#   dframe: a data frame with column names.
#
# Returns:
#   Character vector with the names of the flagged columns, each listed once,
#   in column order. Zero-length when nothing is flagged or when `dframe` has
#   fewer than two columns.
find.duplicates <- function(dframe){
    col.names <- colnames(dframe)
    n.cols <- length(col.names)
    flagged <- logical(n.cols)
    # seq_len() gives an empty loop for 0 or 1 columns, where 1:(n - 1)
    # would count downwards and index columns that do not exist.
    for(i in seq_len(max(n.cols - 1, 0))){
        for(j in (i + 1):n.cols){
            # A column that is already flagged needs no further comparison;
            # skipping it also keeps each name from being reported twice.
            if(!flagged[j] && is.duplicated(dframe[,i], dframe[,j])){
                flagged[j] <- TRUE
            }
        }
    }
    return(col.names[flagged])
}

# Drop every column that repeats an earlier one. unique() makes sure a column
# that matched several earlier columns is counted only once in the message.
duplicatedcols <- unique(find.duplicates(train))
if(length(duplicatedcols) > 0){
    print(paste("Removing", length(duplicatedcols), "Duplicated columns"))
    # drop = FALSE keeps `train` a data frame even if only one column survives.
    train <- train[, !names(train) %in% duplicatedcols, drop = FALSE]
    #test <- test[,!names(test) %in% duplicatedcols]
}else{
    print("No duplicated columns found")
}
## [1] "Removing 769 Duplicated columns"
print("Checking for constants")
## [1] "Checking for constants"
# Find columns of `dframe` that contain a single distinct value.
#
# Args:
#   dframe: a data frame (or tibble) with column names.
#
# Returns:
#   Character vector with the names of the constant columns, in column order.
#   Zero-length when there are none or when `dframe` has no columns. A column
#   made up entirely of NA counts as constant, since unique() reports one value.
get.constants <- function(dframe){
    col.names <- colnames(dframe)
    is.constant <- vapply(
        seq_len(length(col.names)),
        # drop = TRUE extracts the column as a plain vector for tibbles too;
        # without it unique() would see a one-column table, whose length is
        # always 1, and every column would look constant.
        function(i) length(unique(dframe[, i, drop = TRUE])) == 1,
        logical(1)
    )
    return(col.names[is.constant])
}
# Drop columns that hold a single distinct value.
consts <- get.constants(train)

if(length(consts) > 0){
    print(paste("Captured", length(consts), "Constant cols now removing them"))
    # drop = FALSE keeps `train` a data frame even if only one column survives.
    train <- train[, !names(train) %in% consts, drop = FALSE]
    #test <- test[,!names(test) %in% consts]
}else{
    print("No constant columns")
}
## [1] "No constant columns"
# One bar chart of value frequencies per character column of `train`.
par(mar=c(2,2,2,2), mfrow=c(1,1))
# Index names(train) with a logical mask instead of subsetting the data frame
# first: `train[, mask]` collapses to a bare vector when exactly one column
# matches, and names() of that vector is not the column name.
char_cols <- names(train)[vapply(train, is.character, logical(1))]

for(i in char_cols) {
 # p <- qplot(factor(i), data=train, geom="bar", main=i)
 # p
  # Tabulate the column itself; the previous code referred to `tmp`, which was
  # only assigned in a commented-out line.
  barplot(table(train[[i]]), main=i)
}
# Names of every column that is not character-typed. The logical mask is
# applied to names(train) so a single matching column does not collapse.
non_char <- names(train)[!vapply(train, is.character, logical(1))]

# Number of distinct values per non-character column. List-style indexing
# (`train[non_char]`) always yields a data frame, so lapply() walks columns
# even when non_char names just one of them.
train.unique.count <- lapply(train[non_char], function(x) length(unique(x)))
# Columns with at most 10 distinct values are drawn as bar charts.
low_count_nums <- names(unlist(train.unique.count[unlist(train.unique.count)<=10]))

par(mar=c(2,2,2,2),mfrow=c(1,1))
counter <- 1
for(i in low_count_nums) {
 # gg <- ggplot(train, aes(i)) + geom_bar()
 # ggsave(gg,filename=paste("numFactors-",counter,".png",sep=""))
 # counter <- counter + 1
 # print(gg)
  # useNA='ifany' adds a bar for missing values when the column has any.
  barplot(table(train[[i]], useNA='ifany'), main=i, col=c("red"))
}

# Non-character columns with more than 10 distinct values: everything in
# non_char that is not listed in low_count_nums.
non_chars <- setdiff(non_char, low_count_nums)
# Only referenced by the disabled ggsave code inside the loop below.
counter <- 1
# One row, two panels, so each column's histogram and boxplot share a page.
par(mar=c(2,2,2,2),mfrow=c(1,2))
for(i in non_chars) {
  hist(train[[i]], main = i)
 # Disabled ggplot2 alternative that would save one PNG per column:
 # gg <- ggplot(train, aes(i)) + geom_boxplot()
 # counter <- counter + 1
 # ggsave(gg,filename=paste("numManyValues-",counter,".png",sep=""))
  boxplot(train[[i]], main=i)
}

# Score every predictor against `target` and keep the strongest ones.
weights <- information.gain(target~., train)
# Top 30 attributes by importance. head() stops at the number of attributes
# actually available, whereas indexing with 1:30 pads the result with NA rows
# when there are fewer than 30.
highest.gain <- weights[head(order(weights$attr_importance,decreasing=TRUE), 30),,drop=FALSE]
barplot(highest.gain$attr_importance, names.arg=rownames(highest.gain))

# The 10 best attributes become the right-hand side of the model formula;
# head() again avoids NA rows (and therefore "NA" terms) on narrow data.
feature.names <- head(highest.gain, 10)
fol <- formula(paste("target ~",paste(rownames(feature.names),collapse="+")))

print("Spliting data ")
## [1] "Spliting data "
# Split `train` into a fitting part and a validation part. createDataPartition()
# is called with p=0.8 and list=FALSE; `intrain` is then used as a row index.
# NOTE(review): no set.seed() call is visible before this point, so the split
# presumably differs from run to run -- confirm whether that is intended.
intrain <- createDataPartition(train$target, p=0.8, list=FALSE)

# Rows selected by `intrain` are used for fitting, all other rows for validation.
training.set <- train[intrain,]
validation.set <- train[-intrain,]
# Fit a classification tree (method="class") on the formula `fol`, then plot it.
model <- rpart(fol, data=training.set, method="class")
rpart.plot(model, varlen=0, type=4)

# Score the fitted tree on the validation rows; class 1 is the positive class.
predictions <- predict(model, validation.set, type="class")
# Correctly predicted positives, shared by precision and recall.
true.positives <- length(which(predictions == 1 & validation.set$target == 1))
accuracy    <- sum(predictions == validation.set$target) / length(validation.set$target)
precision   <- true.positives / length(which(predictions == 1))
recall      <- true.positives / length(which(validation.set$target == 1))
# F1 is the harmonic mean of precision and recall: 2 * P * R / (P + R).
# The factor 2 was missing, which halved the reported score.
f1.score    <- 2 * (precision * recall) / (precision + recall)


print(paste("accuracy  : " ,accuracy))
## [1] "accuracy  :  0.802353044086774"
print(paste("precision : ", precision))
## [1] "precision :  0.811164181984851"
print(paste("recall    : ", recall))
## [1] "recall    :  0.96526581697095"
print(paste("f1 score  : ", f1.score))
## (recorded before the F1 fix: 0.440765500065539; with the corrected formula
## the precision and recall above give about 0.8815)